{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "toc": true
   },
   "source": [
    "<h1>Table of Contents<span class=\"tocSkip\"></span></h1>\n",
    "<div class=\"toc\"><ul class=\"toc-item\"><li><span><a href=\"#Load-Network\" data-toc-modified-id=\"Load-Network-1\"><span class=\"toc-item-num\">1&nbsp;&nbsp;</span>Load Network</a></span></li><li><span><a href=\"#Project\" data-toc-modified-id=\"Project-2\"><span class=\"toc-item-num\">2&nbsp;&nbsp;</span>Project</a></span></li><li><span><a href=\"#Encode\" data-toc-modified-id=\"Encode-3\"><span class=\"toc-item-num\">3&nbsp;&nbsp;</span>Encode</a></span></li><li><span><a href=\"#Generate-Images\" data-toc-modified-id=\"Generate-Images-4\"><span class=\"toc-item-num\">4&nbsp;&nbsp;</span>Generate Images</a></span></li><li><span><a href=\"#Projected-Latent-Initialization\" data-toc-modified-id=\"Projected-Latent-Initialization-5\"><span class=\"toc-item-num\">5&nbsp;&nbsp;</span>Projected Latent Initialization</a></span></li></ul></div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Project/embed real images into the StyleGAN v2 latent space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import sys\n",
    "import os\n",
    "from datetime import datetime\n",
    "from tqdm import tqdm\n",
    "\n",
    "# ffmpeg installation location, for creating videos\n",
    "plt.rcParams['animation.ffmpeg_path'] = str(Path.home() / \"Documents/dev_tools/ffmpeg-20190623-ffa64a4-win64-static/bin/ffmpeg.exe\")\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# StyleGAN Utils\n",
    "from stylegan_utils import load_network, gen_image_fun, synth_image_fun, create_video\n",
    "\n",
    "import dnnlib\n",
    "import dataset_tool\n",
    "import run_projector\n",
    "import projector\n",
    "import training.dataset\n",
    "import training.misc\n",
    "\n",
    "# Specific of encoder repos, comment out if not needed\n",
    "#from encoder.perceptual_model import PerceptualModel\n",
    "#from encoder.generator_model import Generator\n",
    "\n",
    "# Data Science Utils\n",
    "sys.path.append(os.path.join(os.pardir, 'data-science-learning'))\n",
    "\n",
    "from ds_utils import generative_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res_dir = Path.home() / 'Documents/generated_data/stylegan'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resolve the models folder from the user's home, consistent with res_dir,\n",
    "# instead of hard-coding an absolute per-machine path\n",
    "MODELS_DIR = Path.home() / 'Documents/models/stylegan2'\n",
    "MODEL_NAME = 'original_ffhq'\n",
    "SNAPSHOT_NAME = 'stylegan2-ffhq-config-f'\n",
    "\n",
    "# the snapshot pickle path is handed to load_network as a plain string\n",
    "Gs, Gs_kwargs, noise_vars = load_network(str(MODELS_DIR / MODEL_NAME / SNAPSHOT_NAME) + '.pkl')\n",
    "\n",
    "Z_SIZE = Gs.input_shape[1]\n",
    "IMG_SIZE = Gs.output_shape[2:]\n",
    "IMG_SIZE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Project"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def project_images(images_dir, tfrecord_dir, data_dir, num_steps, num_snapshots):\n",
    "    \"\"\"Project all images of a folder with the StyleGAN2 projector.\n",
    "\n",
    "    Images are converted to tfrecords, read back one at a time and handed to\n",
    "    `run_projector.project_image`. Outputs for the i-th image are written with the\n",
    "    prefix `data_dir/out_<i>/image_`.\n",
    "\n",
    "    Uses the notebook-level network `Gs`.\n",
    "\n",
    "    Args:\n",
    "        images_dir: folder containing the images to project.\n",
    "        tfrecord_dir: folder where the tfrecords are generated.\n",
    "        data_dir: root folder of the projection results (str or Path).\n",
    "        num_steps: forwarded to `projector.Projector`.\n",
    "        num_snapshots: forwarded to `run_projector.project_image`.\n",
    "    \"\"\"\n",
    "    data_dir = Path(data_dir)\n",
    "\n",
    "    # setup projector\n",
    "    print('Setting up projector')\n",
    "    proj = projector.Projector(num_steps=num_steps)\n",
    "    proj.set_network(Gs)\n",
    "    \n",
    "    # generate tfrecords\n",
    "    # NOTE(review): relies on create_from_images returning the number of images - confirm\n",
    "    # against the dataset_tool version in use\n",
    "    nb_images = dataset_tool.create_from_images(str(tfrecord_dir), str(images_dir), True)\n",
    "\n",
    "    # loading images from tfrecords (both paths passed as plain strings)\n",
    "    dataset_obj = training.dataset.load_dataset(data_dir=str(data_dir), tfrecord_dir=str(tfrecord_dir), \n",
    "                                                max_label_size=0, verbose=True, repeat=False, shuffle_mb=0)\n",
    "    assert dataset_obj.shape == Gs.output_shape[1:]\n",
    "    \n",
    "    # project all loaded images\n",
    "    print('=======================')\n",
    "    for image_idx in tqdm(range(nb_images)):\n",
    "        print(f'Projecting image {image_idx}/{nb_images}')\n",
    "        \n",
    "        images, _labels = dataset_obj.get_minibatch_np(1)\n",
    "        images = training.misc.adjust_dynamic_range(images, [0, 255], [-1, 1])\n",
    "        \n",
    "        run_path = data_dir / f'out_{image_idx}'\n",
    "        # tolerate a missing parent folder and re-runs on an existing output folder\n",
    "        run_path.mkdir(parents=True, exist_ok=True)\n",
    "        run_projector.project_image(proj, targets=images, \n",
    "                                    png_prefix=dnnlib.make_run_dir_path(str(run_path / 'image_')), \n",
    "                                    num_snapshots=num_snapshots)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = res_dir / 'projection' / MODEL_NAME / SNAPSHOT_NAME / datetime.now().strftime(\"%Y%m%d_%H%M%S\") # where the projections results will be saved\n",
    "images_dir = Path.home() / 'Documents/generated_data/face_extract' / 'tmp_portraits'\n",
    "\n",
    "tfrecord_dir = data_dir / 'tfrecords'\n",
    "project_images(images_dir=images_dir, tfrecord_dir=tfrecord_dir, data_dir=data_dir, \n",
    "               num_steps=1000, num_snapshots=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "create_video(data_dir / 'out_7', \n",
    "             res_dir / 'projection' / 'out_{}.mp4'.format('test'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from config import Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Config.get_inception_path()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Encode\n",
    "This does not use the official StyleGAN v2 projector, but instead relies on the direct encoder setup used by the community for v1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder-repo helpers are imported here because their imports in the first cell are\n",
    "# commented out; without them this cell fails with a NameError on a fresh kernel\n",
    "from encoder.perceptual_model import PerceptualModel\n",
    "from encoder.generator_model import Generator\n",
    "\n",
    "BATCH_SIZE = 1\n",
    "PERCEPTUAL_MODEL_IMG_SIZE = 256\n",
    "\n",
    "# setup utils generator and perceptual model\n",
    "generator = Generator(Gs, BATCH_SIZE, randomize_noise=False)\n",
    "perceptual_model = PerceptualModel(PERCEPTUAL_MODEL_IMG_SIZE, layer=9, batch_size=BATCH_SIZE)\n",
    "perceptual_model.build_perceptual_model(generator.generated_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_to_batches(l, n):\n",
    "    for i in range(0, len(l), n):\n",
    "        yield l[i:i + n]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_images(images_dir, data_dir, iterations, learning_rate=1.):\n",
    "    \"\"\"Encode every image file of `images_dir` into a dlatent by direct optimization.\n",
    "\n",
    "    Uses the notebook-level `generator`, `perceptual_model` and `BATCH_SIZE`.\n",
    "    For each source image, the generated image is saved as PNG under\n",
    "    `data_dir/<iterations>/gen_images` and the dlatent as .npy under\n",
    "    `data_dir/<iterations>/latents`, both named after the source image.\n",
    "\n",
    "    Args:\n",
    "        images_dir: pathlib.Path of the folder holding the reference images.\n",
    "        data_dir: pathlib.Path root folder for the results.\n",
    "        iterations: forwarded to `perceptual_model.optimize` for every batch.\n",
    "        learning_rate: forwarded to `perceptual_model.optimize`.\n",
    "    \"\"\"\n",
    "    # local import: the notebook only imports PIL in a later cell\n",
    "    from PIL import Image\n",
    "\n",
    "    # collect images: files only, sorted so the processing order is deterministic\n",
    "    images_paths = sorted(str(img) for img in images_dir.glob('*') if img.is_file())\n",
    "    \n",
    "    run_dir = data_dir / '{}'.format(iterations)\n",
    "    GEN_IMAGES_DIR = run_dir / 'gen_images'\n",
    "    GEN_DLATENT_DIR = run_dir / 'latents'\n",
    "    GEN_IMAGES_DIR.mkdir(parents=True, exist_ok=True)\n",
    "    GEN_DLATENT_DIR.mkdir(parents=True, exist_ok=True)\n",
    "    \n",
    "    # ceil division, so a trailing partial batch is counted by the progress bar\n",
    "    nb_batches = (len(images_paths) + BATCH_SIZE - 1) // BATCH_SIZE\n",
    "\n",
    "    # encode all collected images, one batch at a time\n",
    "    for images_batch in tqdm(split_to_batches(images_paths, BATCH_SIZE), total=nb_batches):\n",
    "        images_names = [os.path.splitext(os.path.basename(img_path))[0] for img_path in images_batch]\n",
    "\n",
    "        perceptual_model.set_reference_images(images_batch)\n",
    "        optimizer = perceptual_model.optimize(generator.dlatent_variable, \n",
    "                                              iterations=iterations, \n",
    "                                              learning_rate=learning_rate)\n",
    "        # stays None if the optimizer yields nothing (e.g. iterations=0)\n",
    "        loss = None\n",
    "        for loss in tqdm(optimizer, leave=False, mininterval=9, total=iterations):\n",
    "            pass  # exhaust the optimizer; only the last loss is reported\n",
    "        print(' '.join(images_names), ' loss:', loss)\n",
    "\n",
    "        # generate images from found dlatents and save them\n",
    "        generated_images = generator.generate_images()\n",
    "        generated_dlatents = generator.get_dlatents()\n",
    "        for img_array, dlatent, img_name in zip(generated_images, generated_dlatents, images_names):\n",
    "            img = Image.fromarray(img_array, 'RGB')\n",
    "            img.save(str(GEN_IMAGES_DIR / f'{img_name}.png'), 'PNG')\n",
    "            np.save(str(GEN_DLATENT_DIR / f'{img_name}.npy'), dlatent)\n",
    "\n",
    "        generator.reset_dlatents()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate Images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_latents = np.random.rand(18, Z_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img = gen_image_fun(Gs, target_latents, Gs_kwargs, noise_vars, truncation_psi=0.5)\n",
    "plt.imshow(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img = synth_image_fun(Gs, target_latents[np.newaxis,:,:], Gs_kwargs, randomize_noise=True)\n",
    "plt.imshow(img)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Projected Latent Initialization\n",
    "Test the network used to learn an initial mapping from an image to the intermediate StyleGAN latent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "from keras.models import load_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pass a plain string: not every Keras release accepts a pathlib.Path here\n",
    "resnet = load_model(str(MODELS_DIR / MODEL_NAME / 'resnet' / 'finetuned_resnet.h5'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "resnet_img_size = (256, 256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "resnet.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default to the first file of the projection input folder (an empty path always fails);\n",
    "# point target_img_path to any other image to test it instead\n",
    "target_img_path = next(p for p in sorted(images_dir.glob('*')) if p.is_file())\n",
    "# NOTE(review): assumes the resnet expects 3-channel RGB input - confirm with resnet.summary()\n",
    "target_img = Image.open(str(target_img_path)).convert('RGB')\n",
    "target_img = target_img.resize(resnet_img_size)\n",
    "plt.imshow(target_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predicted_latent = resnet.predict(np.array(target_img)[np.newaxis,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img = synth_image_fun(Gs, predicted_latent, Gs_kwargs, randomize_noise=True)\n",
    "plt.imshow(img)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "StyleGAN",
   "language": "python",
   "name": "stylegan"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
